import torch
import numpy as np
import scipy.sparse as sp
import torch.nn as nn
import metrics
from parser1 import args
def get_performance(user_pos_test, r, auc, Ks):
    """Collect top-K ranking metrics for one user at every cutoff in ``Ks``.

    Args:
        user_pos_test: the user's test positives; only ``len(user_pos_test)``
            is used, as the recall denominator.
        r: ranking vector forwarded unchanged to the ``metrics`` helpers
            (presumably a 0/1 hit list ordered by predicted score — confirm
            against ``metrics``).
        auc: precomputed AUC value, passed through untouched.
        Ks: iterable of cutoffs K.

    Returns:
        dict with keys 'recall', 'precision', 'ndcg', 'hit_ratio' (each a
        numpy array with one entry per K, in ``Ks`` order) and 'auc'.
    """
    # Insertion order here fixes the key order of the returned dict.
    scores = {'recall': [], 'precision': [], 'ndcg': [], 'hit_ratio': []}
    for cutoff in Ks:
        scores['precision'].append(metrics.precision_at_k(r, cutoff))
        scores['recall'].append(metrics.recall_at_k(r, cutoff, len(user_pos_test)))
        scores['ndcg'].append(metrics.ndcg_at_k(r, cutoff))
        scores['hit_ratio'].append(metrics.hit_at_k(r, cutoff))

    result = {name: np.array(values) for name, values in scores.items()}
    result['auc'] = auc
    return result

def get_rat(user_feature, item_feature, dis, feat_dim=None):
    """Score every (user, item) feature pair with ``dis`` and pick each row's best item.

    ``user_feature.shape[1]`` is taken as the number of candidate items per
    row.  Both feature tensors are flattened to rows of width ``feat_dim`` and
    concatenated side by side before being passed to ``dis``; they must
    flatten to the same number of rows.  Presumably both are shaped
    (batch, n_item, feat_dim) — TODO confirm against callers.

    Args:
        user_feature: user-side features; its dim 1 gives the item count.
        item_feature: item-side features, flattened the same way.
        dis: callable mapping a (rows, 2 * feat_dim) tensor to per-row scores
            (e.g. a discriminator module).  Singleton dims of the regrouped
            scores beyond the first two are dropped.
        feat_dim: width of one flattened feature row.  Defaults to
            ``args.emb_size * 2`` (the previously hard-coded value).

    Returns:
        (ratting, item_index): ``ratting`` holds the scores regrouped as
        (batch, n_item, ...) with size-1 dims after the first two removed;
        ``item_index`` is the top-1 index along dim 1 with that dim removed.
        Runs under ``torch.no_grad()``, so neither output carries gradients.
    """
    if feat_dim is None:
        feat_dim = args.emb_size * 2
    n_item = user_feature.shape[1]
    with torch.no_grad():
        input_ = torch.cat(
            [user_feature.view(-1, feat_dim), item_feature.view(-1, feat_dim)], -1
        )
        rat = dis(input_)
        # Regroup the flat per-pair scores into consecutive blocks of n_item.
        ratting = torch.stack(torch.split(rat, n_item, 0))
        # Drop singleton dims only after the first two.  A bare .squeeze() also
        # removed the batch dim when batch == 1 (or the item dim when
        # n_item == 1), which made topk(dim=1) below raise.
        trailing = [s for s in ratting.shape[2:] if s != 1]
        ratting = ratting.reshape(*ratting.shape[:2], *trailing)
        # Squeeze only the k=1 dim so a single-row batch still yields one index per row.
        item_index = torch.topk(ratting, k=1, dim=1).indices.squeeze(1)
    return ratting, item_index


def initialize_weights(layer):
    """Xavier-uniform initialise an ``nn.Linear`` layer's weight and zero its bias.

    Any module that is not an ``nn.Linear`` is returned from immediately and
    left unchanged, so this can be passed to ``nn.Module.apply``.
    """
    if not isinstance(layer, nn.Linear):
        return
    nn.init.xavier_uniform_(layer.weight.data)
    bias = layer.bias
    # Linear layers built with bias=False have no bias tensor to reset.
    if bias is not None:
        nn.init.zeros_(bias.data)

def matrix_to_tensor(cur_matrix, device="cuda"):
    """Convert a SciPy sparse matrix into a float32 torch sparse COO tensor.

    Args:
        cur_matrix: SciPy sparse matrix; anything that is not already a
            ``coo_matrix`` is converted with ``tocoo()``.
        device: device for the result.  Defaults to ``"cuda"`` (the previous
            hard-coded behaviour); pass ``"cpu"`` on machines without a GPU.

    Returns:
        Sparse COO tensor of dtype float32 with the same shape and non-zeros
        as ``cur_matrix``, located on ``device``.
    """
    if not isinstance(cur_matrix, sp.coo_matrix):
        cur_matrix = cur_matrix.tocoo()
    # torch expects int64 indices laid out as a (2, nnz) row/col stack.
    indices = torch.from_numpy(
        np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64)
    )
    # Cast before construction so the tensor is built as float32 directly,
    # instead of relying on the deprecated legacy torch.sparse.FloatTensor ctor.
    values = torch.from_numpy(cur_matrix.data).to(torch.float32)
    shape = torch.Size(cur_matrix.shape)
    return torch.sparse_coo_tensor(indices, values, shape).to(device)

def graph_drop(graph, keepRate):
    """Randomly drop stored entries of a sparse COO tensor (edge dropout).

    Each stored entry is kept independently with probability ``keepRate`` and
    kept values are divided by ``keepRate``, so every entry's expected value
    is unchanged.

    Args:
        graph: torch sparse COO tensor (read via ``_values``/``_indices``, so
            it need not be coalesced).
        keepRate: keep probability; 1.0 keeps everything, 0.0 drops everything.

    Returns:
        A new sparse COO tensor with ``graph.shape``, on the same device as
        ``graph``, holding only the kept entries.
    """
    vals = graph._values()
    idxs = graph._indices()
    edge_num = vals.size()
    # rand is uniform on [0, 1), so floor(rand + keepRate) is non-zero exactly
    # when rand >= 1 - keepRate, i.e. with probability keepRate.  Draw on the
    # values' device so a CUDA graph is not masked by a CPU-generated tensor.
    mask = (torch.rand(edge_num, device=vals.device) + keepRate).floor().to(torch.bool)
    new_vals = vals[mask] / keepRate
    new_idxs = idxs[:, mask]
    return torch.sparse_coo_tensor(new_idxs, new_vals, graph.shape)